Speeding up the Highly Adaptive Lasso

Salvador Balkus

Harvard T.H. Chan School of Public Health

What is function estimation?

Consider a function (e.g., a conditional mean). If we sample function values with errors, how can we learn the original form?

Iteratively splitting (decision tree)

(Animation: regression-tree fits with 1, 2, 4, 6, 8, and 10 bins.)

Highly Adaptive Lasso (HAL)

(Animation: HAL fits along the regularization path, \(\lambda\) = 0, 1.6e-5, 3.1e-5, 5.0e-5, 8.7e-5, 0.00017.)

Why use HAL?

Tree-based models:

  • depend on hyperparameters to stop
  • no guarantee of convergence (?)

HAL:

Function classes

| Function class | Example estimators | Rate |
|---|---|---|
| Parametric \(f_\theta(x)\) | Linear regression | \(\sqrt{n}\) |
| Additive \(\sum f_j(x_j)\) | GAM / 1-D spline | \(\sqrt[3]{n}\) |
| Cadlag (\(M < \infty\)) | HAL | \(\sqrt[4]{n}\) |
| Lipschitz | Honest forest, BART | \(n^{-2/(d+2)}\) |
| No assumptions | Random forest | Possibly biased |

Why cadlag?

(Animation: a Binomial CDF, then Binomial CDF + Normal CDF.)

“cadlag” := right continuous, left limits

  • cadlag is preserved by subtraction, multiplication, integration, etc.

\(\sum_{i=1}^n g(x_i) (F(x_{i}) - F(x_{i-1})) \approx \int g(x)\,dF(x)\) is cadlag
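As a quick numerical check of the claim above, the Riemann–Stieltjes sum for a smooth \(g\) against a CDF recovers the ordinary integral. Here \(F\) is the Uniform(0,1) CDF and \(g(x) = x^2\), chosen for illustration, so the exact answer is \(1/3\):

```python
import numpy as np

# Riemann-Stieltjes sum sum_i g(x_i) * (F(x_i) - F(x_{i-1}))
# approximating the integral of g dF for F = Uniform(0,1) CDF.
x = np.linspace(0.0, 1.0, 1001)
g = x**2
F = x  # the Uniform(0,1) CDF is the identity on [0, 1]
approx = np.sum(g[1:] * np.diff(F))
# approx is close to the exact value 1/3
```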

Finite sectional variation (\(M < \infty\))

  • section: collection of dimensions \(S \subset \{x_1,\ldots,x_d\}\)
  • sectional variation: \(M = \sum_{S} \int |f_S(du)|\) (total “oscillation”)

NOT finite sectional variation:

HAL Formula

Representation theorem (Benkeser and van der Laan 2016; van der Laan 2017): All cadlag functions can be written

\[f = \sum_{S \subset \{1,\ldots,d\}}\int I(u \leq x_S)f_S(du)\]

Approximating via empirical measure yields

\[f_n(x) = \sum_S\sum_{i=1}^n I(x_{S, i} \leq x_S)\beta_{S, i} = \Phi(x)\beta\]

HAL Formula

\[\arg\min_{\lVert\beta\rVert_1 \leq M} \frac{1}{n}\sum_{i=1}^n L(y_i, \Phi(x_i)\beta)\]
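A minimal sketch of enumerating the full design matrix \(\Phi(x)\) above: one zero-order spline indicator per observed knot and nonempty section. The function name `hal_basis` is hypothetical, not an existing API:

```python
import numpy as np
from itertools import combinations

def hal_basis(X):
    """Enumerate the full HAL basis on the observed data.

    For every nonempty section S and every observation i, evaluates the
    indicator phi_{S,i}(x) = I(x_j >= x_{i,j} for all j in S).
    Returns the n x (n * (2^d - 1)) design matrix Phi.
    """
    n, d = X.shape
    cols = []
    for size in range(1, d + 1):
        for S in combinations(range(d), size):
            XS = X[:, S]  # (n, |S|) coordinates in section S
            # Compare every point against every knot in this section.
            cols.append(np.all(XS[:, None, :] >= XS[None, :, :], axis=2))
    return np.concatenate(cols, axis=1).astype(float)
```

Each observation always activates its own basis function, so the diagonal of each section's block is 1; the column count grows as \(n(2^d - 1)\), which is the space-complexity problem discussed below.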

Pros and cons

Pros:

  • covers almost all reasonable functions
  • use glmnet to optimize common loss functions
  • breaks curse of dimensionality
  • necessary statistical properties hold

Con: Must enumerate basis \(\Phi(x)\). How big is it?

\(n\) basis functions \(\times\) \(2^d - 1\) sections =

\(\Huge{O(n\cdot 2^d)}\)

space complexity

Speeding up HAL

  • Current time complexity: \(O(n^2 \cdot 2^d)\)
  • The \(2^d\) factor makes this too slow for even medium-dimensional problems
  • Need to eliminate the exponential factor

Idea: Randomly sample basis functions to approximate HAL

  • How many do we need?

Speeding up HAL

RandomHAL: Fit \(f_{m,n}\) as follows:

  1. Sample \(m\) basis functions \(\phi_{D, i}(x)\) from \(\Phi(x)\)
    1. Draw dimension of section \(D \sim \text{Binomial}(d, p)\)
    2. Draw knot index \(i \sim \text{DiscreteUniform}(1, n)\)
  2. Fit lasso over sub-basis \(\Phi_m(x)\)
  3. \(f_{m,n} = \Phi_m(x) \beta_{\text{lasso}}\)
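The sampling step above can be sketched as follows; `sample_basis` is a hypothetical helper, and in practice step 2 would fit the lasso over the returned matrix (e.g., with glmnet):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_basis(X, m, p):
    """Step 1 of RandomHAL: draw m random zero-order spline basis
    functions phi_{S,i}(x) = I(x_S >= x_{S,i}) and evaluate them on X."""
    n, d = X.shape
    Phi_m = np.empty((n, m))
    for j in range(m):
        # 1a. Section size D ~ Binomial(d, p), redrawn until nonempty.
        D = 0
        while D == 0:
            D = rng.binomial(d, p)
        S = rng.choice(d, size=D, replace=False)  # which dimensions form the section
        i = rng.integers(n)                       # 1b. knot index ~ DiscreteUniform(1, n)
        Phi_m[:, j] = np.all(X[:, S] >= X[i, S], axis=1)
    return Phi_m
```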

\(\Phi_m(x) \subset \Phi(x)\), but does it converge?

Speeding up HAL

  • Consider risk \(L\), function class \(\mathcal{F}\), and risk minimizer \(f_n = \arg\min_{f \in \mathcal{F}} \frac{1}{n}\sum_{i=1}^n L(f(x_i))\).
  • Now consider sequence \(f_{m,n} = \arg\min_{f \in \mathcal{F}_m} \frac{1}{n}\sum_{i=1}^n L(f(x_i))\) in nested subsets

\[\mathcal{F}_1 \subset \mathcal{F}_2 \subset \ldots \subset \mathcal{F}_m \subset \mathcal{F}\]

Clearly RandomHAL satisfies this

Speeding up HAL

Theorem 1: Given nested risk minimization, \[\lVert f - f_{m,n} \rVert \leq \lVert f - f_n \rVert + \lVert f_n - f_{m,n} \rVert = o_P(\max(r_n, r_m))\]

where

  • \(r_n\) is estimation rate of \(f_n\) to \(f\)
  • \(r_m\) is approximation rate of \(f_{m,n}\) to \(f_n\)

If \(r_m\) is faster than \(r_n\), then \(f_{m,n}\) preserves the rate of \(f_n\)

Speeding up HAL

  • Lasso is fit using coordinate descent (Friedman, Hastie, and Tibshirani 2010)
  • Conceptualize RandomHAL \(f_{m,n}\) as a sequence of \(m\) random coordinate updates
  • Shalev-Shwartz and Tewari (2011) prove that random coordinate descent over \(k\) coordinates satisfies \(\lVert f_{m,n} - f_n\rVert^2 = O(k/m)\)
  • The HAL basis has \(k = nd\), so taking \(m \asymp n^{3/2}\) coordinate updates makes the approximation error \(O_P(n^{-1/4})\), preserving HAL’s \(n^{-1/4}\) estimation rate!
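Spelling out the rate arithmetic behind the last bullet (a sketch, treating the dimension \(d\) as fixed):

```latex
\lVert f_{m,n} - f_n \rVert^2
  = O\!\left(\frac{k}{m}\right)
  = O\!\left(\frac{nd}{m}\right),
\qquad
m \asymp n^{3/2}
\;\Longrightarrow\;
\lVert f_{m,n} - f_n \rVert = O\!\left(n^{-1/4}\right)
```

so the approximation error decays at least as fast as HAL’s \(n^{-1/4}\) estimation error.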

But we can do even better!

Nesterov (2012): if the loss is strongly convex (grows at least quadratically, rather than linearly, away from its minimum), then random coordinate descent satisfies

\[||f_{m,n} - f_n||^2 < C_1\Big(1 - \frac{C_2}{k}\Big)^m\]

so for HAL we actually only need

\[m = O(n\log(n))\]

How well does it actually work?

Example simulation: 200 replicates with \(d = 6\)

\[\begin{align*} X_1, \ldots, X_4 &\sim \text{Beta}(\alpha, \beta) \\ X_5 &\sim \text{Ber}(0.5) \\ A &\sim \text{Ber}(\text{expit}(X_1X_5 + \sum_{k=1}^4 (X_k + X_k^2))) \\ Y &\sim \text{Normal}(A + X_5A + 2X_3X_4\\ & + \sum_{k=1}^4 (X_k - X_k^{3/2}), 0.3) \\ \end{align*}\]
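A NumPy sketch of the data-generating process above. The Beta shape parameters \(\alpha, \beta\) are not stated on the slide, so they are left as arguments the user must supply; the function name `simulate` is hypothetical:

```python
import numpy as np

def simulate(n, alpha, beta, seed=0):
    """Draw one dataset of size n from the d = 6 simulation design."""
    rng = np.random.default_rng(seed)
    X = rng.beta(alpha, beta, size=(n, 4))          # X1..X4 ~ Beta(alpha, beta)
    X5 = rng.binomial(1, 0.5, size=n)               # X5 ~ Ber(0.5)
    lin = X[:, 0] * X5 + np.sum(X + X**2, axis=1)   # X1*X5 + sum(Xk + Xk^2)
    A = rng.binomial(1, 1 / (1 + np.exp(-lin)))     # A ~ Ber(expit(lin))
    mu = A + X5 * A + 2 * X[:, 2] * X[:, 3] + np.sum(X - X**1.5, axis=1)
    Y = rng.normal(mu, 0.3)                         # Y ~ Normal(mu, 0.3)
    return X, X5, A, Y
```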

How well does it actually work?

One-step ATE, HAL nuisances

Many variables

Does RandomHAL still work when there are too many variables to fit HAL?

Error: from glmnet Fortran code (error code 7777); All used predictors have zero variance

Many variables: the problem

Recall step 1.a: Draw section size \(D \sim \text{Binomial}(d, p)\).

(Figure: distribution of section sizes with \(p = 0.5\) vs. \(p = 0.05\).)

50 confounder simulation

20 replicates with \(p = 0.05\) of

\[\begin{align*} & X_1, \ldots, X_{50} \sim \text{Beta}(2, 2) \\ & A \sim \text{Ber}(0.1 + 0.8I(\sum_{k=1}^{50} 0.2(X_k - X_k^{3/2}) > 1)) \\ & Y \sim \text{Normal}((A+10)\sum_{k=1}^{50} 0.2(X_k - X_k^{3/2}), 0.1) \\ \end{align*}\]


Further questions + ongoing

  • Any \(\ell_2\)-regularized loss function is strongly convex
    • Can “Highly Adaptive Elasticnet” guarantee fast \(n\log(n)\) speed for non-strongly-convex losses?
  • New time complexity: \(O(n^2\log(n))\); still slow!
    • XGBoost bins covariates into histograms to speed up splits
    • Can running HAL over histogram with \(1/\sqrt{n}\) bin width preserve convergence?
  • How to choose optimal \(p\) when \(d\) is large?
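The histogram idea in the list above could look like the following hypothetical pre-processing step (`bin_covariates` is not an existing function): snapping each covariate to a grid of width \(1/\sqrt{n}\) makes knots repeat, so far fewer distinct basis functions need to be enumerated.

```python
import numpy as np

def bin_covariates(X):
    """Snap each covariate to a histogram grid of width 1/sqrt(n).

    Duplicate knot values then collapse to a single basis function,
    shrinking the design matrix; whether this preserves HAL's
    convergence rate is the open question posed above.
    """
    n = X.shape[0]
    width = 1.0 / np.sqrt(n)
    return np.floor(X / width) * width
```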

Further questions + ongoing

  • Other sampling schemes
    • Avoid throwing out treatment for ATE
  • Undersmoothing / automatic debiasing (Chernozhukov et al. 2021)
    • Choose \(m\) or \(\lambda\) to reduce bias
    • Only need outcome regression, no propensity

Questions?

References

Benkeser, David, and Mark van der Laan. 2016. “The Highly Adaptive Lasso Estimator.” In 2016 IEEE International Conference on Data Science and Advanced Analytics (DSAA), 689–96. IEEE. https://doi.org/10.1109/dsaa.2016.93.
Chernozhukov, Victor, Whitney K. Newey, Victor Quintas-Martinez, and Vasilis Syrgkanis. 2021. “Automatic Debiased Machine Learning via Riesz Regression.” https://doi.org/10.48550/ARXIV.2104.14737.
Fang, Billy, Adityanand Guntuboyina, and Bodhisattva Sen. 2021. “Multivariate Extensions of Isotonic Regression and Total Variation Denoising via Entire Monotonicity and Hardy–Krause Variation.” The Annals of Statistics 49 (2). https://doi.org/10.1214/20-aos1977.
Friedman, Jerome, Trevor Hastie, and Robert Tibshirani. 2010. “Regularization Paths for Generalized Linear Models via Coordinate Descent.” Journal of Statistical Software 33 (1). https://doi.org/10.18637/jss.v033.i01.
Nesterov, Yu. 2012. “Efficiency of Coordinate Descent Methods on Huge-Scale Optimization Problems.” SIAM Journal on Optimization 22 (2): 341–62. https://doi.org/10.1137/100802001.
Richtárik, Peter, and Martin Takáč. 2012. “Iteration Complexity of Randomized Block-Coordinate Descent Methods for Minimizing a Composite Function.” Mathematical Programming 144 (1–2): 1–38. https://doi.org/10.1007/s10107-012-0614-z.
Shalev-Shwartz, Shai, and Ambuj Tewari. 2011. “Stochastic Methods for ℓ1-Regularized Loss Minimization.” Journal of Machine Learning Research 12: 1865–92.
van der Laan, Mark. 2017. “A Generally Efficient Targeted Minimum Loss Based Estimator Based on the Highly Adaptive Lasso.” The International Journal of Biostatistics 13 (2). https://doi.org/10.1515/ijb-2015-0097.
———. 2023. “Higher Order Spline Highly Adaptive Lasso Estimators of Functional Parameters: Pointwise Asymptotic Normality and Uniform Convergence Rates.” arXiv. https://doi.org/10.48550/ARXIV.2301.13354.